{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial \n",
    "\n",
    "The release of Deepseek R1 shocked the industry. Why? Well, DeepSeek-R1 is an open model that rivals OpenAI's o1 in complex reasoning tasks, introduced using Group Relative Policy Optimization (GRPO) and RL-focused multi-stage training approach. They not only released the model, but also a research paper on how they did it. \n",
    "\n",
    "In the paper they described an \"aha moment\" when using pure RL to train the model. During this phase, DeepSeek-R1-Zero (the first test of DeepSeek-R1) learns to allocate more thinking time to a problem by reevaluating its initial approach without any human feedback or data describing how to do it.  They describe this as an \"aha moment\" as:\n",
    " \n",
    "> This behavior is not only a testament to the model’s growing reasoning abilities but also a captivating example of how reinforcement learning can lead to unexpected and sophisticated outcomes.\n",
    "\n",
    "It serves as a powerful reminder of the potential of RL to unlock new levels of intelligence in artificial systems, paving the way for more autonomous and adaptive models in the future. In this blog post we want to recreate the small \"aha moment\" of DeepSeek-R1 using Group Relative Policy Optimization (GRPO) and the Countdown Game. We will train an open model using reinforcement learning trying to teach it self-verification and search abilities all on its own to solve the Countdown Game. The Countdown game is a numbers puzzle where players use a set of randomly drawn numbers and basic arithmetic operations (+, -, ×, ÷) to reach or get as close as possible to a target number.\n",
    "\n",
    "```\n",
    "Target Number: 952\n",
    "Available Numbers: 25, 50, 75, 100, 3, 6\n",
    "\n",
    "(100 × (3 × 3)) + (50 + 6 / 3) = 952\n",
    "```\n",
    "\n",
    "The blog post includes an interactive code which you can run in a Jupyter Notebook on how to train a model using GRPO and Q-Lora. This is a great way to learn how to use TRL and GRPO, but it is very slow and requires a lot of compute. Additionally, I added a [script](https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/scripts/run_r1_grpo.py) and instructions to run the training on Node with multiple GPUs or a SLURM cluster. \n",
    "\n",
    "1. [Setup the development environment](#1-setup-the-development-environment)\n",
    "2. [Generate training samples with reasoning prefix from the Countdown Game](#2-generate-training-samples-with-reasoning-prefix-from-the-countdown-game)\n",
    "3. [Train the model using GRPO](#3-train-the-model-using-grpo)\n",
    "4. [Distributed Training example for GRPO using Deepspeed and vLLM](#4-distributed-training-example-for-grpo-using-deepspeed-and-vllm)\n",
    "5. [Results and Training Observations](#5-results-and-training-observations)\n",
    "\n",
    "_Note: This blog is inspired by [Jiayi Pan](https://x.com/jiayi_pirate/status/1882839370505621655) who initially explored the idea and proofed it with a small model._\n",
    "\n",
    "But Before we start, let's take a look at the [Group Relative Policy Optimization (GRPO)](https://arxiv.org/abs/2402.03300) and understand how it works.\n",
    "\n",
    "## Group Relative Policy Optimization (GRPO)\n",
    "\n",
    "Group Relative Policy Optimization (GRPO) is a reinforcement learning algorithm to improve the reasoning capabilities of LLMs. It was introduced in the [DeepSeekMath](https://arxiv.org/abs/2402.03300) paper in the context of mathematical reasoning. GRPO modifies the traditional Proximal Policy Optimization (PPO) by eliminating the need for a value function model. Instead, it estimates baselines from group scores, reducing memory usage and computational overhead. GRPO, now also used by the Qwen team, can be used with rule/binary-based Rewards as well as General Reward Models to improve models on helpfulness. \n",
    "\n",
    "1. **Sampling**: Generate multiple outputs for each prompt using the current policy\n",
    "2. **Reward Scoring**: Each generation is scored using a reward function, could be (rule-based or outcome-based)\n",
    "3. **Advantage Calculation**: The average reward of the generated outputs is used as a baseline. The advantage of each solution within the group is then computed relative to this baseline. The reward is normalized within a group.\n",
    "4. **Policy Optimization**: The policy tries to maximize the GRPO objective, which includes the calculated advantages and a KL divergence term. This is different from how PPO implements the KL term within the reward.\n",
    "\n",
    "![grpo.png](../assets/grpo.png)"
   ]
  },
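  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the advantage calculation step above more concrete, here is a minimal sketch of a group-relative advantage for a single prompt. The helper name `group_relative_advantages`, the small `eps` constant, and the use of the population standard deviation are assumptions for illustration, not taken from the DeepSeekMath paper or from TRL, and real implementations may differ in these details.\n",
    "\n",
    "```python\n",
    "from statistics import mean, pstdev\n",
    "\n",
    "\n",
    "def group_relative_advantages(rewards, eps=1e-4):\n",
    "    # rewards: one score per completion sampled for the same prompt (one group)\n",
    "    baseline = mean(rewards)  # the group mean acts as the baseline\n",
    "    spread = pstdev(rewards)  # population std of the group (assumption)\n",
    "    # eps avoids a division by zero when all rewards in the group are equal\n",
    "    return [(r - baseline) / (spread + eps) for r in rewards]\n",
    "\n",
    "\n",
    "# hypothetical rewards for 4 completions of one prompt\n",
    "advantages = group_relative_advantages([2.0, 1.0, 1.0, 0.0])\n",
    "print(advantages)\n",
    "```\n",
    "\n",
    "Completions that score above the group mean get a positive advantage and those below get a negative one. If all completions in a group receive the same reward, every advantage is zero and the group carries no learning signal."
   ]
  },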
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup development environment\n",
    "\n",
    "Our first step is to install Hugging Face Libraries and Pytorch, vllm, and trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune, rlhf, align open LLMs. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Pytorch & other libraries, make sure to match your GPU driver version\n",
    "%pip install \"torch==2.5.1\" tensorboard \"setuptools<71.0.0\"  --index-url https://download.pytorch.org/whl/cu121\n",
    "\n",
    "# Install flash-attn\n",
    "%pip install flash-attn \n",
    "\n",
    "# Install Hugging Face libraries\n",
    "%pip install  --upgrade \\\n",
    "  \"transformers==4.48.1\" \\\n",
    "  \"datasets==3.1.0\" \\\n",
    "  \"accelerate==1.3.0\" \\\n",
    "  \"hf-transfer==0.1.9\" \\\n",
    "  \"deepspeed==0.15.4\" \\\n",
    "  \"trl==0.14.0\"\n",
    "\n",
    "# install vLLM \n",
    "%pip install \"vllm==0.7.0\"\n",
    "\n",
    "## IMPORTANT: If you want to run the notebook and the interactive cells you also need to install the following libraries:\n",
    "# But first read it the blog post and then decide as they might conflict with the libraries for distributed training. \n",
    "# %pip install \"peft==0.14.0\" \"bitsandbytes==0.45.0\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_Note: you may need to restart the kernel to use updated packages._\n",
    "\n",
    "We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training. You must register on the [Hugging Face](https://huggingface.co/join) for this. After you have an account, we will use the `login` util from the `huggingface_hub` package to log into our account and store our token (access key) on the disk.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "login(token=\"\", add_to_git_credential=True) # ADD YOUR TOKEN HERE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Generate training samples with reasoning prefix from the Countdown Game\n",
    "\n",
    "We are going to use the [Jiayi-Pan/Countdown-Tasks-3to4](https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4) dataset, which contains samples with 3 to 4 numbers and solutions.\n",
    "\n",
    "As Model we are going to use [Qwen/Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) which is a 3B parameter instruction tuned model. This makes it easier to showcase the \"aha moment\" as it already follows the prompt format. But you can use the base version of Qwen or other models as well. [Jiayi-Pan](https://x.com/jiayi_pirate/status/1882839487417561307) explored that the model needs to have a certain quality to be able to learn the reasoning process, starting with > 1.5B parameters. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load dataset from Hugging Face Hub\n",
    "dataset_id = \"Jiayi-Pan/Countdown-Tasks-3to4\"\n",
    "dataset = load_dataset(dataset_id, split=\"train\")\n",
    "# select a random subset of 50k samples\n",
    "dataset = dataset.shuffle(seed=42).select(range(50000))\n",
    "\n",
    "# Load tokenizer from Hugging Face Hub to format the dataset to our \"r1\" prompt \n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-3B-Instruct\")\n",
    "\n",
    "# gemerate r1 prompt with a prefix for the model to already start with the thinking process\n",
    "def generate_r1_prompt(numbers, target):\n",
    "    r1_prefix = [{\n",
    "        \"role\": \"system\",\n",
    "        \"content\": \"You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.\"\n",
    "      },\n",
    "      { \n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 = 1 </answer>.\"\n",
    "      },\n",
    "      {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": \"Let me solve this step by step.\\n<think>\"\n",
    "      }]\n",
    "    return {\"prompt\": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True), \"target\": target}\n",
    "\n",
    "# convert our dataset to the r1 prompt\n",
    "dataset = dataset.map(lambda x: generate_r1_prompt(x[\"nums\"], x[\"target\"]))\n",
    "\n",
    "# split the dataset into train and test\n",
    "train_test_split = dataset.train_test_split(test_size=0.1)\n",
    "\n",
    "train_dataset = train_test_split[\"train\"]\n",
    "test_dataset = train_test_split[\"test\"]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Train the model using GRPO\n",
    "\n",
    "_Note: Section 3 is shows the basic on how to use TRL and GRPO. If you want to run the interactive cells you need to install `bitsandbytes` and `peft` as they are required for the `Trainer` class. This section is mostly for educational purposes._\n",
    "\n",
    "TRL supports Group Relative Policy Optimization (GRPO) through a dedicated [GRPOTrainer](https://huggingface.co/docs/trl/main/en/grpo_trainer) for aligning LLMs from preference data, as described in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/abs/2402.03300). The `GRPOTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, checkpointing, distributed training, and parameter efficient fine-tuning (PEFT). \n",
    "\n",
    "The `GRPOTrainer` supports generic Outcome Reward Models (ORM) and custom reward functions, that can be used to implement Rule-Based Reward Models. In the Deepseek R1 paper they implemented Rule-Based Reward Models to verify the correctness of the generated solutions. In our exmaple we are going to do a similar approach, where we will create 2 reward functions that: \n",
    "1. Format Reward: Checks if the generated format is correct `<think> [thinking] </think><answer> [answer] </answer>`\n",
    "2. Accuracy Reward: Extracts the equation from the `<answer>` tag and evaluates it against the target and if every number is used once.\n",
    "\n",
    "_Note: Correct `<answer>` in our example includes the equation, for example `<answer> 55 + 36 - 7 - 19 </answer>`_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def format_reward_func(completions, target, **kwargs):\n",
    "    \"\"\"\n",
    "    Format: <think>...</think><answer>...</answer>\n",
    "    Args:\n",
    "        completions (list[str]): Generated outputs\n",
    "        target (list[str]): Expected answers\n",
    "      \n",
    "      Returns:\n",
    "          list[float]: Reward scores\n",
    "    \"\"\"\n",
    "    rewards = []\n",
    "\n",
    "    for completion, gt in zip(completions, target):\n",
    "\n",
    "      try:\n",
    "        # add synthetic <think> as its already part of the prompt and prefilled for the assistant to more easily match the regex\n",
    "        completion = \"<think>\" + completion        \n",
    "        # Check if the format is correct\n",
    "        regex = r\"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\\/think>\\n<answer>([\\s\\S]*?)<\\/answer>$\"\n",
    "\n",
    "        match = re.search(regex, completion, re.DOTALL) \n",
    "        # if the format is not correct, reward is 0\n",
    "        if match is None or len(match.groups()) != 2:\n",
    "            rewards.append(0.0)\n",
    "        else:\n",
    "            rewards.append(1.0)\n",
    "      except Exception:\n",
    "        rewards.append(0.0)\n",
    "    return rewards\n",
    "\n",
    "def equation_reward_func(completions, target, nums, **kwargs):\n",
    "    \"\"\"\n",
    "    Evaluates completions based on:\n",
    "    2. Mathematical correctness of the answer\n",
    "\n",
    "    Args:\n",
    "        completions (list[str]): Generated outputs\n",
    "        target (list[str]): Expected answers\n",
    "        nums (list[str]): Available numbers\n",
    "    \n",
    "    Returns:\n",
    "        list[float]: Reward scores\n",
    "    \"\"\"\n",
    "    rewards = []\n",
    "    for completion, gt, numbers in zip(completions, target, nums):\n",
    "      try:\n",
    "        # add synthetic <think> as its already part of the prompt and prefilled for the assistant to more easily match the regex\n",
    "        completion = \"<think>\" + completion\n",
    "        # Check if the format is correct\n",
    "        match = re.search(r\"<answer>(.*?)<\\/answer>\", completion)\n",
    "        if match is None:\n",
    "            rewards.append(0.0)\n",
    "            continue\n",
    "        # Extract the \"answer\" part from the completion\n",
    "        equation = match.group(1).strip()\n",
    "        # Extract all numbers from the equation\n",
    "        used_numbers = [int(n) for n in re.findall(r'\\d+', equation)]\n",
    "        \n",
    "        # Check if all numbers are used exactly once\n",
    "        if sorted(used_numbers) != sorted(numbers):\n",
    "            rewards.append(0.0)\n",
    "            continue\n",
    "        # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace\n",
    "        allowed_pattern = r'^[\\d+\\-*/().\\s]+$'\n",
    "        if not re.match(allowed_pattern, equation):\n",
    "           rewards.append(0.0)\n",
    "           continue\n",
    "        \n",
    "        # Evaluate the equation with restricted globals and locals\n",
    "        result = eval(equation, {\"__builti'ns__\": None}, {})\n",
    "        # Check if the equation is correct and matches the ground truth\n",
    "        if abs(float(result) - float(gt)) < 1e-5:\n",
    "            rewards.append(1.0)\n",
    "        else:\n",
    "            rewards.append(0.0)\n",
    "      except Exception:\n",
    "            # If evaluation fails, reward is 0\n",
    "            rewards.append(0.0) \n",
    "    return rewards"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets try our reward function with a sample. \n",
    "\n",
    "_Note: None of the example starts with `<think>` as we added it synthetically to the prompt._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct_sample_1 = \"\"\"We need to find an equation using the numbers 19, 36, 55, and 7\n",
    "exactly once, with basic arithmetic operations, that equals 65. One possible\n",
    "combination is 55 + 36 - 19 + 7... </think>\n",
    "<answer> 55 + 36 - 7 - 19 </answer>\"\"\"\n",
    "\n",
    "correct_sample_2 = \"\"\" ... </think>\n",
    "<answer> 55 + 36 - 7 - 19 </answer>\"\"\"\n",
    "\n",
    "wrong_format = \"\"\"User: Using the numbers [19, 36, 55, 7], create an equation that equals 65.\"\"\"\n",
    "\n",
    "wrong_format_2 = \"\"\"To find the equation that equals 79 using the numbers 95, 78, 6, 88, I'll start by adding 88 and 95:                      \n",
    "95 + 88 = 183                                                                                                              \n",
    "Now, let's subtract 104 from 183 to get 79:\n",
    "183 - 104 = 79\n",
    "<think> 183 - 104 = 79 </think><think> 183 - 104 = 79 </think><answer> 183 - 104 = 79 </answer>\"\"\"\n",
    "\n",
    "wrong_result = \"\"\" ... </think>\n",
    "<answer> 55 + 36 - 7 - 18 </answer>\"\"\"\n",
    "\n",
    "\n",
    "test_rewards = format_reward_func(completions=[correct_sample_1, correct_sample_2, wrong_format, wrong_format_2, wrong_result], target=[\"65\", \"65\", \"65\", \"65\", \"65\"], nums=[[19, 36, 55, 7]] * 5)\n",
    "assert test_rewards == [1.0, 1.0, 0.0, 0.0, 1.0], \"Reward function is not working\"\n",
    "test_rewards = equation_reward_func(completions=[correct_sample_1, correct_sample_2, wrong_format, wrong_format_2, wrong_result], target=[\"65\", \"65\", \"65\", \"65\", \"65\"], nums=[[19, 36, 55, 7]] * 5)\n",
    "assert test_rewards == [1.0, 1.0, 0.0, 0.0, 0.0], \"Reward function is not working\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This looks good, now lets define our remaining training parameters, create a trainer and start training. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig\n",
    "\n",
    "# our model we are going to use as policy \n",
    "model_config = ModelConfig(\n",
    "    model_name_or_path=\"Qwen/Qwen2.5-3B-Instruct\",\n",
    "    torch_dtype=\"bfloat16\",\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    "    use_peft=True,\n",
    "    load_in_4bit=True,\n",
    ")\n",
    "\n",
    "# Hyperparameters\n",
    "training_args = GRPOConfig(\n",
    "    output_dir=\"qwen-r1-aha-moment\",\n",
    "    learning_rate=5e-7,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    logging_steps=10,\n",
    "    max_steps=100,\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=1,\n",
    "    gradient_checkpointing=True,\n",
    "    gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
    "    bf16=True,\n",
    "    # GRPO specific parameters\n",
    "    max_prompt_length=256,\n",
    "    max_completion_length=1024, # max length of the generated output for our solution\n",
    "    num_generations=2,\n",
    "    beta=0.001,\n",
    "    \n",
    ")\n",
    "trainer = GRPOTrainer(\n",
    "    model=model_config.model_name_or_path,\n",
    "    reward_funcs=[format_reward_func, equation_reward_func],\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=test_dataset,\n",
    "    peft_config=get_peft_config(model_config),\n",
    ")"
   ]
  },
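  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we pass two reward functions to the trainer, every completion receives two scores. The sketch below shows how such scores could be combined into one reward per completion, assuming the per-function scores are simply added up. That is my understanding of how `GRPOTrainer` handles multiple reward functions, so treat it as an assumption. The names `toy_format_reward`, `toy_answer_reward` and `combined_rewards` are hypothetical stand-ins and not part of TRL.\n",
    "\n",
    "```python\n",
    "def toy_format_reward(completions, **kwargs):\n",
    "    # hypothetical stand-in: 1.0 if the completion ends with a closing answer tag\n",
    "    return [1.0 if c.strip().endswith('</answer>') else 0.0 for c in completions]\n",
    "\n",
    "\n",
    "def toy_answer_reward(completions, target, **kwargs):\n",
    "    # hypothetical stand-in: 1.0 if the target string appears in the completion\n",
    "    return [1.0 if str(t) in c else 0.0 for c, t in zip(completions, target)]\n",
    "\n",
    "\n",
    "def combined_rewards(reward_funcs, completions, **kwargs):\n",
    "    # call every reward function and add up the scores per completion\n",
    "    per_func = [func(completions=completions, **kwargs) for func in reward_funcs]\n",
    "    return [sum(scores) for scores in zip(*per_func)]\n",
    "\n",
    "\n",
    "completions = ['<answer> 65 </answer>', '<answer> 64 </answer>', 'no tags']\n",
    "totals = combined_rewards(\n",
    "    [toy_format_reward, toy_answer_reward], completions, target=['65', '65', '65']\n",
    ")\n",
    "print(totals)\n",
    "```\n",
    "\n",
    "Under this assumption, a completion in our setup can score between 0.0 and 2.0: one point for the format and one point for a correct equation."
   ]
  },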
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can start our training by calling the `train` method on the trainer instance. \n",
    "\n",
    "_Note: Reinforcement Training is very slow and compute intensive. Running a single step on 1x L4 with Q-LoRA, Batch size of 1 and only 2 generations per samples takes >20 minutes._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train and push the model to the Hub\n",
    "trainer.train()\n",
    "# Save model\n",
    "trainer.save_model(training_args.output_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Distributed Training example for GRPO using Deepspeed and vLLM\n",
    "\n",
    "More than 20 minutes per step with only 2 generations per sample is not feasible. We need to scale up our training. Hugging Face TRL added support for distributed training with Deepspeed and using vLLM for faster generation. I preprared a [run_r1_grpo.py](https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/scripts/run_r1_grpo.py) script and a [receipes/grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml](https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/receipes/grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml) config file to run the training. \n",
    "\n",
    "This configuration is tested and validated on a Node with 4x H100 80GBs, where a single step takes around 45-60s, as we can leverage vLLM for generation and DeepSpeed for distributed training. Therefore we need to make sure we correctly set the `num_processes` to the number of GPUs you have - 1 as the last one will be used with vLLM for Generation. If you are using more GPUS you need to change the `vllm_device` in the config file to last index GPU, e.g. if you have 8 GPUs you need to set `vllm_device=7` and your `num_processes` to 7.\n",
    "\n",
    "command to run the training:\n",
    "```bash\n",
    "accelerate launch --num_processes 3 --config_file configs/accelerate_configs/deepspeed_zero3.yaml scripts/run_r1_grpo.py --config receipes/grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml\n",
    "```\n",
    "\n",
    "With the optimized distributed training a single step with 8 generations per sample on 4x H100 80GBs takes around 45-60s. The full training for 450 steps takes around 6 hours. "
   ]
  },
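  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GPU split described above (one GPU reserved for vLLM, all others used for training) can be sketched as a small helper. `plan_gpu_split` is a hypothetical function for illustration and not part of the script or of TRL. It only computes the two values you would pass to `--num_processes` and set as `vllm_device` in the config file.\n",
    "\n",
    "```python\n",
    "def plan_gpu_split(num_gpus):\n",
    "    # one GPU is reserved for vLLM generation, the rest are used for training\n",
    "    if num_gpus < 2:\n",
    "        raise ValueError('need at least 2 GPUs: one for vLLM and one for training')\n",
    "    # GPU indices start at 0, so the last GPU has index num_gpus - 1\n",
    "    return {'num_processes': num_gpus - 1, 'vllm_device': num_gpus - 1}\n",
    "\n",
    "\n",
    "print(plan_gpu_split(4))\n",
    "print(plan_gpu_split(8))\n",
    "```\n",
    "\n",
    "For the 4x H100 setup used here this gives 3 training processes with vLLM on GPU index 3, and for 8 GPUs it gives 7 training processes with vLLM on GPU index 7."
   ]
  },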
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Results and Training Observations\n",
    "\n",
    "The script saves random completions to the `completion_samples` folder, which you can use to inspect the model's progress. It includes `completion_samples.txt` and `success_completion_samples.txt`. The `completion_samples.txt` includes all completions, while the `success_completion_samples.txt` which correctly solves the equation. Below you can find the interesating training obeserations on how the performance changes over time, as well as the Tensornoard logs and successfull reasoning samples.\n",
    "\n",
    "The model with checkpoints for every 25th step can be found at [philschmid/qwen-2.5-3b-r1-countdown](https://huggingface.co/philschmid/qwen-2.5-3b-r1-countdown). \n",
    "\n",
    "### Hyperparameters\n",
    "\n",
    "I started the experiment using the hyperparameters from the [DeepSeekMath](https://arxiv.org/abs/2402.03300) paper with a learning rate of 1e-6 and a beta (KL coefficient) of 0.04, which led to unstable training runs after around 150 steps. I ran some small ablations and decreased both the learning rate to 5e-7 and the beta to 0.001, based on a test from [OpenRLHF](https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05). I coulnd't test how increasing the `num_generations` from 8 to 64 would affect the training. 64 is the generation value, which was used in the DeepSeekMath paper. All other parameters can be found in the [grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml](https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/receipes/grpo-qwen-2.5-3b-deepseek-r1-countdown.yaml) config file.\n",
    "\n",
    "\n",
    "### Training Observations: \n",
    "\n",
    "* At ~50 steps the model has learned the correct format `<think>...</think>\\n<answer>...</answer>`. \n",
    "* At 100 steps the success rate for solving the equation is around 25%. The model starts to \"reason\" with words see examples below. \n",
    "* At 200 steps the performance seems to converge much slower and we are at ~40% success rate. The model starts to learn a new \"format\" where it solves the equation similar to how you would do it programmatically, by trying different combinations and reviewing the results, see \"Successfull Reasoning Samples between step 200 and 450\".\n",
    "* At 450 steps we have 50% success rate for solving the equation. The performance still improves slowly and the model kept its new format form from step 200.\n",
    "\n",
    "I have 3 potential assumptions why the model shifts from \"word reasoning\" to \"programmatic execution\":\n",
    "1. Qwen 2.5 3B is not strong enough or to small, Deepseek mentions that you need a very strong base model. \n",
    "2. The reward functions are not good enough defined and it reward hacks the model to solve the equation. We could try to force it to use words, e.g. having a number to word frequency condition. (We don't know much about the reward functions from Deepseek)\n",
    "3. Only training on the Countdown Game tasks might force the model naturally to learn the most effective way to solve the equation as no other formats are required. \n",
    "4. We didn't train the model long enough, in the R1 paper they showed visuals of training for over 8000 steps. \n",
    "\n",
    "\n",
    "#### Tensorboard Results\n",
    "\n",
    "![Qwen R1 Tensorboard results](../assets/tensorboard-r1.png)\n",
    "\n",
    "\n",
    "#### Successfull Reasoning Samples up to step 200\n",
    "\n",
    "```xml\n",
    "<think> First, I need to think about the numbers we have: 14, 70, and 80. To get 85, I think 80 might be a good starting point because it's right in between 85. If I subtract or add 80, it won't lead to 85 easily. Next, 70 seems to be a good intermediate number, especially to use in conjunction with 14. \n",
    "Now, I can think about different combinations using addition, subtraction and multiplication/division. If I take 80 - 70 = 10, which is not needed. So, I need to find a way to use 14. If I subtract 14 from 80, I get 66, which still isn't 85. If I subtract 70 from 80, I get 10 again, this isn't constructive yet. \n",
    "Multiplication could create numbers larger than the target 85. Let's consider division: If I divide 70 by 14, I get 5. Then if I add this result to 80, I get 85. This combination seems promising: 80 - (70 / 14) = 80 - 5 = 75, this still isn't quite right. Let me correct my approach: Let's use division and then addition. If I take 80 / (70 / 14) → 80 / 5 = 16, and then 80 - 16 = 64, still incorrect. Let me simplify: 80 / (70 / 14) → 80 / 5 = 16, and then 80 - 16 gives 64 again, let me correct: 80 - (70 / 14) → 80 - 5 = 75, but I need 85. Thus I will use the combinations in a slightly different order to insert numbers: I get 80 + (70 / 14) = 80 + 5 = 85</think>\n",
    "<answer> 80 + (70 / 14) </answer>\n",
    "```\n",
    "\n",
    "```xml\n",
    "<think> First, I need to consider how I can use the numbers 1, 21, and 90 to reach 68. I can try different combinations of operations. Let's try to use subtraction and addition. If I subtract 90 from 21, I get -69, which is too far from 68. Now, if I add 1 to 90, I get 91, which is also not right. Another way is to see what happens if I subtract 21 from 90, which gives me 69. This is closer but still not 68. Let me use multiplication instead of subtraction. 90 * 1 = 90. Now, if I try to reduce this slightly by subtracting 21, I am not getting close to 68.\n",
    "Now, let's try another multiplication. I'll think of 1 * (larger value) + (smaller value). Let's try 1 * 90 + 21. This evaluates to 90 + 21 = 111, which is too much. Instead, I will try: 90 / 1 + 21 - 7. But this gives a division by 7, which seems complex. Let's re-evaluate: I'll use 1 to divide 21: 21 / 1 = 21. Now, subtract 21 from 90, and add 1, we have 90 - 21 = 69, using one...\n",
    "Hey, wait on re-evaluating I will subtract 21 from 90, and include 1 to complete the solution: 90 - (21 + 1) which would be 68.</think>\n",
    "<answer> 90 - (21 + 1) </answer>\n",
    "```\n",
    "\n",
    "\n",
    "```xml\n",
    "<think> First, I need to consider the numbers we have: 59, 80, and 66. I need to use these numbers to get 45. Let's try to use subtraction first since it can quickly reduce large numbers.\n",
    "Let's see if we can get close by subtracting or adding these numbers. If I subtract 66 from 80, I get 14. Then I need to make this 14 equal 45 when combined with 59. That gives me 14 + 59 = 73, which is too high.\n",
    "Now let's try adding numbers. If I add 59 and 66, I get 125. Then I would need to subtract 80 from 125 to get 45, but 80 is too large to fit well with this.\n",
    "Let me try a different approach. What if I subtract 66 from 80 first and then add 59? That would be 14 + 59 = 73, which is still not 45.\n",
    "Hmm, I need to think of another combination. What if I start with 59 and then add 80 and subtract 66? That would be (59 + 80) - 66, let me calculate it. That equals 139 - 66, which is 73.\n",
    "That's still not 45, so let me try rearranging it slightly: maybe 80 - 66 + 59 would be closer to 45. Calculating again: 80 - 66 is 14, then adding 59 gives 73.\n",
    "This way doesn't work either. Let's flip this: maybe (59 + 66) - 80. Let me check: that's 125 - 80, which equals 45.</think>\n",
    "<answer>(59 + 66) - 80</answer>\n",
    "```\n",
    "\n",
    "```xml\n",
    "<think> First, I need to find a combination of the numbers 61, 4, 94, and 21 that can be manipulated using basic arithmetic operations to equal 50. One way to approach this is to look for numbers that can be used to reach a target number close to 50 and then adjust with the remaining numbers.\n",
    "Let's try combining 94 and 21 first:\n",
    "94 + 21 = 115\n",
    "Now, we need to get from 115 to 50. We can try subtracting 65:\n",
    "115 - 65 = 50\n",
    "Now we need to find a way to get 65 using the remaining numbers 61 and 4. We can try:\n",
    "61 - 4 = 57\n",
    "Now we need to get from 57 to 65. Adding 8 would give us that:\n",
    "57 + 8 = 65\n",
    "So, we need to find a way to get 8 using 61 and 4. We can try:\n",
    "61 - 94 = -36\n",
    "This is negative, so let's try dividing:\n",
    "61 / 94 = 0.648987...\n",
    "This is not helping, so let's try multiplication and division in a different way:\n",
    "(61 - 4) * (94 / 21) = 57 * 4.476... = 254.4\n",
    "This is not working either. Let's try a different approach:\n",
    "(94 - 61) = 33\n",
    "Now we have:\n",
    "33 + 21 = 54\n",
    "We need to adjust from 54 to 50, so we can:\n",
    "54 - 4 = 50\n",
    "So, the equation becomes:\n",
    "(94 - 61) + (21 - 4) = 50\n",
    "</think>\n",
    "<answer>(94 - 61) + (21 - 4)</answer>\n",
    "```\n",
    "\n",
    "```xml\n",
    "<think> First, I'll consider the possibility of directly subtracting two numbers to get close to 31. If I subtract 57 from 63, I get 7. Now, if I subtract this 7 from 25, I get 25 - 7 = 18. Far from 31, so this doesn't work.\n",
    "Next, let's try adding and subtracting. Dividing 63 by 25 will give a decimal, so that's unlikely to work. Now, let's see if adding and subtracting in another order could work. For example, let's try 63 - 25, and then we see this gets us 38. Now, we need a way to get to 31 from here. Subtracting 7 from 38 would work, but we can't figure out how to create that 7 from the given numbers.\n",
    "Hmm, let me try to incorporate all numbers by adding and subtracting. If I try 63 + 25, I get 88. This is quite far from 31, but if 88 - 57, I get 31 exactly!\n",
    "This seems to be the equation: 63 + 25 - 57 equals 31.\n",
    "</think>\n",
    "<answer> 63 + 25 - 57 </answer>\n",
    "```\n",
    "\n",
    "#### Successfull Reasoning Samples between step 200 and 450\n",
    "\n",
    "```xml\n",
    "<think> We need to use the numbers 39, 16, and 55 with basic arithmetic operations to get 78. Let's try different combinations:\n",
    "- 55 + 39 - 16 = 84 - 16 = 68 (too low)\n",
    "- 55 + 16 - 39 = 71 - 39 = 32 (too low)\n",
    "- 55 - 16 + 39 = 39 + 39 = 78 (this works)\n",
    "So, the equation is 55 - 16 + 39 = 78. </think>\n",
    "<answer> 55 - 16 + 39 </answer>\n",
    "```\n",
    "\n",
    "```xml\n",
    "<think> We need to use the numbers 21, 5, and 40 with basic arithmetic operations to get 24. Let's try different combinations:\n",
    "- 21 - 5 - 40 = -15\n",
    "- 21 - 5 + 40 = 56\n",
    "- 21 + 5 - 40 = -14\n",
    "- 21 * 5 / 40 = 2.1 (not an integer, let's try another)\n",
    "- 40 * 5 / 21 = 9.52 (not 24)\n",
    "- 40 / 5 - 21 = 9 - 21 = -12\n",
    "- 5 * 40 / 21 = 100 / 21 ~= 4.76 (not 24)\n",
    "- 5 + 40 - 21 = 24 ! This works.\n",
    "</think>\n",
    "<answer> 5 + 40 - 21 </answer>\n",
    "```\n",
    "\n",
    "```xml\n",
    "<think> We need to use the numbers 75, 99, 62, and 34 with basic arithmetic operations to get 72. Let's try different combinations:\n",
    "- 99 - 75 + 62 - 34 = 92 - 34 = 58 (too high)\n",
    "- 99 - 75 - 62 + 34 = 24 - 62 + 34 = -38 + 34 = -4 (too low)\n",
    "- 99 + 75 - 62 - 34 = 174 - 96 = 78 (too high)\n",
    "- 99 + 75 - 62 + 34 = 174 + 34 - 62 = 208 - 62 = 146 (too high)\n",
    "- 99 - 75 + 62 + 34 = 24 + 96 = 120 (too high)\n",
    "- 75 + 99 - 62 - 34 = 174 - 96 = 78 (too high)\n",
    "- 75 + 99 - 62 + 34 = 174 + 34 - 62 = 208 - 62 = 146 (too high)\n",
    "- 75 + 62 - 99 + 34 = 137 - 99 + 34 = 38 + 34 = 72\n",
    "So, 75 + 62 - 99 + 34 equals 72.\n",
    "</think>\n",
    "<answer> 75 + 62 - 99 + 34 </answer>\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "The release of DeepSeek R1 and its research paper might be breakpoint for the open-science and open-source development. Just a week after DeepSeek release, we've been able to reproduce a simple version of R1 learned \"reasoning\" using GRPO and the Countdown Game. While our implementation focuses on a specific task rather than general reasoning and convergence into a very specific \"reasoning\" format, it shows that the method is working. \n",
    "\n",
    "In our mini R1 experiment we used GRPO, with two rule-based reward but already required significant compute: 4 H100 GPUs running for 6 hours to complete just 450 training steps on a 3B parameter model. This gives us an idea of the compute needs that you will need to scale Reinforcement Learning. Deepseek ran a 671B model for over 8000 steps and they probably did many ablations.\n",
    "\n",
    "Looking in to 2025, it's clear that we are on the cusp of even more significant progress. RL will become more accessible and user-friendly, more researchers and developers will explore its potential, but also require amount of more compute as before and compared to supervised fine-tuning. \n",
    "\n",
    "I am excited for 2025. If you are have any question or ideas feel free to reach out to me. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
